# 数据集验证，需要加载的包库
import os
from PIL import Image
from matplotlib import pyplot as plt
import xml
import cv2 as cv
import numpy as np
import json
import math
from scipy.spatial import distance
from collections import OrderedDict
import time
import pandas as pds
from torch.utils.data import DataLoader
from torchvision import models
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimer
from logdir import loginfo
from .tool import calaccurary,visualize_cam
from torchsummary import summary
from utilour.gradcadpp import GradCAM
import torchvision
from torchvision.utils import make_grid, save_image

class ModelCalAnalys(object):
    def __init__(self,models,imgoutpath="."):
        """注册基本文件，将数据进行注册构建"""
        self.activateinfo=[]
        self.gradients=[]
        self.model=models
        self.imgoutpath=imgoutpath # 确定图片的输出文件路径
        def register_hook(module):
            def forward_hook(module,input_act,output):
                '''前向传播'''
                typecls=0
                if isinstance(module,nn.Sequential):
                    typecls=1
                elif isinstance(module, nn.ModuleList):
                    typecls=2
                class_name = str(module.__class__).split(".")[-1].split("'")[0]
                self.activateinfo.append({"type":typecls,"name":"{}_{}".format(class_name,len(self.activateinfo)),"input":[t.detach().cpu().numpy()  for t in input_act],"output":output.detach().cpu().numpy()})
            
            def backward_hook(module,grad_in,grad_out):
                '''后向传播'''
                typecls=0
                if isinstance(module,nn.Sequential):
                    typecls=1
                elif isinstance(module, nn.ModuleList):
                    typecls=2
                class_name = str(module.__class__).split(".")[-1].split("'")[0]
                self.gradients.append({"type":typecls,"name":"{}_{}".format(class_name,len(self.activateinfo)),"grad_in":grad_in.detach().cpu().numpy(),"grad_out":grad_out.detach().cpu().numpy()})
            
            module.register_forward_hook(forward_hook)
            module.register_backward_hook(backward_hook)
        
        # 绑定数据
        models.apply(register_hook)
    def forward(self,input_data,retain_graph=True,clsn=3):
        return self.model(input_data) # 返回模型执行结果

    def saveImage(self,slabel,tlabel,endtype="txt"):
        '''保存文件，确定文件的保存路径------方法，已经被废弃，需要消耗大量的内存进行存储，'''
        """
            这个默认是针对图像分类的标签进行处理
            如果你想针对图片的分割情况进行划分的话，请自行修改相关的代码，
            注意这里有关处理信息的代码，基本都是我针对特定情况进行设置，如果你想对这些数据继续扩充，其自行就修改相关的代码
            应该将每一层的计算结果，都放到一个长行图片中，最后再将所有的数据图片放到一个图片进行展示。
            多个批次，就存放多张图片。
            需要注意的是：cv的文件输出，有着文件大小的限制。
            注意：
            [[图片信息]，[特征图1],[特征图2],[特征图。。。。]] ----每一行的扩充图片，都是最后一张图片的重复使用
        """
        font=cv.FONT_HERSHEY_SIMPLEX#使用默认字体

        b,cs=tlabel.size()
        bimglist=[[] for i in range(b)]
        # 首先保存原始图片
        simgdata=0
        for layer in self.activateinfo:
            if layer["type"]==0:
                simgdata=layer["input"][0]
                break
        # 获取原始图片的信息 b:批次数。c:通道数。h:高，w:宽
        b,c,h,w=simgdata.shape
        # 生成图片的信息+原始图像
        for i in range(b):
            # 图片的信息
            tl=slabel.detach().cpu().numpy()[i]
            reslabel=tlabel.detach().cpu().numpy()[i,:]
            gl=np.argmax(reslabel)
            imginfo_cam=np.zeros((h,w,c),np.uint8) # 新建图像，注意一定要是uint8
            imginfo_cam=cv.putText(imginfo_cam,'{}-T-{}'.format(tl,gl),(0,50),font,1,(255,255,255),3)#添加文字，1.2表示字体大小，（0,40）是初始的位置，(255,255,255)表示颜色，2表示粗细
            imginfo_cam=imginfo_cam/255.0
            imginfo_cam=np.transpose(imginfo_cam,[2,0,1])
            # 初始化，保存原始图片
            bimglist[i].append([imginfo_cam,simgdata[i,:,:,:]]) #第一行保存图片信息+原始图片
        
        # 其次保存所有计算结果的图片
        for layer in self.activateinfo:
            # 保存计算信息
            # 计算层的信息
            finfo_img=np.zeros((h,w,c),np.uint8)
            finfo_img=cv.putText(finfo_img,'T:{}'.format(layer["type"]),(0,30),font,0.9,(255,255,255),3)
            finfo_img=cv.putText(finfo_img,'F:{}'.format(layer["name"]),(0,60),font,0.9,(255,255,255),3)
            finfo_img=cv.putText(finfo_img,'I:{}'.format(layer["input"][0].shape),(0,120),font,0.9,(255,255,255),3) 
            finfo_img=cv.putText(finfo_img,'O:{}'.format(layer["output"].shape),(0,180),font,0.9,(255,255,255),3) 
            finfo_img=finfo_img/255.0
            findo_img=torch.from_numpy(np.transpose(finfo_img,[2,0,1]))
            # 开始保存每次计算的大小，并缩放到合适的原始大小
            for i in range(b):
                #tempimgs=F.upsample(layer["output"],size=(h,w),mode="bilinear",align_corners=False)
                if len(layer["output"].shape)<4: # 这不是图片，是对应的分类信息
                    b,vs=layer["output"].shape
                    for v1 in range(vs):
                        imgs.append(layer["output"][i,v1])
                             # 这里暂时先解决图片的抽象处理问题，
                             # 并不解决对应的类似于encoding和decoding的激活情况的结构，如果想自行扩展，可以使用其他方式进行扩展
                b,k,h1,w1=layer["output"].shape
                imgs=[]
                for k1 in range(k):
                    imgs.append(layer["output"][i,k1,:,:])
                bimglist[i].append(imgs)
        # 打印数据并文件转换成合适的数据进行展示                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
        # 最后保存结果
        self.layercalResult=bimglist
        return bimglist
        pass